{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Road Follower - Train Model\n",
    "\n",
    "In this notebook we will train a neural network that takes an input image and outputs a pair of x, y values corresponding to a target.\n",
    "\n",
    "We will use the PyTorch deep learning framework to train a ResNet-18 model for the road follower application."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import glob\n",
    "import PIL.Image\n",
    "import os\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download and extract data\n",
    "\n",
    "Before you start, upload the ``road_following_<Date&Time>.zip`` file that you created in the ``data_collection.ipynb`` notebook on the robot.\n",
    "\n",
    "> If you're training on the JetBot you collected data on, you can skip this!\n",
    "\n",
    "Then extract the dataset by running the command below. If your zip file has a different name than the one in the command, change the command to match:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -q road_following.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should see a folder named ``dataset_xy`` appear in the file browser. This is the folder that the dataset code in this notebook loads."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Dataset Instance\n",
    "\n",
    "Here we create a custom ``torch.utils.data.Dataset`` implementation, which implements the ``__len__`` and ``__getitem__`` methods.  This class\n",
    "is responsible for loading images and parsing the x, y values from the image filenames.  Because we implement the ``torch.utils.data.Dataset`` class,\n",
    "we can use all of the torch data utilities :)\n",
    "\n",
    "We hard-coded some transformations (like color jitter) into our dataset.  Random horizontal flips are optional and controlled by the ``random_hflips`` argument, because they only make sense for symmetric paths: on a road\n",
    "where we need to 'stay right', a mirrored image would teach the robot to stay left.  If it doesn't matter whether your robot follows such a convention, you can enable flips to augment the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_x(path):\n",
    "    \"\"\"Gets the x value from the image filename, rescaled so that 50 maps to 0.0 and 100 to 1.0\"\"\"\n",
    "    return (float(int(path[3:6])) - 50.0) / 50.0\n",
    "\n",
    "def get_y(path):\n",
    "    \"\"\"Gets the y value from the image filename, rescaled so that 50 maps to 0.0 and 100 to 1.0\"\"\"\n",
    "    return (float(int(path[7:10])) - 50.0) / 50.0\n",
    "\n",
    "class XYDataset(torch.utils.data.Dataset):\n",
    "    \n",
    "    def __init__(self, directory, random_hflips=False):\n",
    "        self.directory = directory\n",
    "        self.random_hflips = random_hflips\n",
    "        self.image_paths = glob.glob(os.path.join(self.directory, '*.jpg'))\n",
    "        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.image_paths)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        image_path = self.image_paths[idx]\n",
    "        \n",
    "        image = PIL.Image.open(image_path)\n",
    "        x = float(get_x(os.path.basename(image_path)))\n",
    "        y = float(get_y(os.path.basename(image_path)))\n",
    "        \n",
    "        # Flip only when enabled; mirror the image and negate x so the label still matches\n",
    "        if self.random_hflips and np.random.rand() > 0.5:\n",
    "            image = transforms.functional.hflip(image)\n",
    "            x = -x\n",
    "        \n",
    "        image = self.color_jitter(image)\n",
    "        image = transforms.functional.resize(image, (224, 224))\n",
    "        image = transforms.functional.to_tensor(image)\n",
    "        # Reverse the channel order (RGB -> BGR)\n",
    "        image = image.numpy()[::-1].copy()\n",
    "        image = torch.from_numpy(image)\n",
    "        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        \n",
    "        return image, torch.tensor([x, y]).float()\n",
    "    \n",
    "dataset = XYDataset('dataset_xy', random_hflips=False)"
   ]
  },
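  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the filename parsing above, the sketch below decodes a made-up filename. The ``xy_<x>_<y>_<id>.jpg`` layout and the helper name ``parse_xy`` are assumptions for illustration, inferred from the slice positions used in ``get_x`` and ``get_y``; check them against the files in your own dataset folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_xy(filename):\n",
    "    \"\"\"Hypothetical helper: decode x, y from a name like 'xy_075_025_<id>.jpg'.\"\"\"\n",
    "    # Characters 3-5 hold x and characters 7-9 hold y, each as a 3-digit number.\n",
    "    x = (int(filename[3:6]) - 50.0) / 50.0\n",
    "    y = (int(filename[7:10]) - 50.0) / 50.0\n",
    "    return x, y\n",
    "\n",
    "# Made-up filename: stored values 75 and 25 map to +0.5 and -0.5.\n",
    "example_x, example_y = parse_xy('xy_075_025_example.jpg')\n",
    "print((example_x, example_y))  # (0.5, -0.5)\n",
    "\n",
    "# A horizontal flip mirrors the image, so only the sign of x changes.\n",
    "print((-example_x, example_y))  # (-0.5, -0.5)"
   ]
  },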
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split dataset into train and test sets\n",
    "Once the dataset is loaded, we split it into training and test sets. In this example we use a 90%-10% split. The test set is used to check how well the trained model performs on images it was not trained on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_percent = 0.1\n",
    "num_test = int(test_percent * len(dataset))\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - num_test, num_test])"
   ]
  },
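  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split sizes follow directly from ``int(test_percent * len(dataset))``. The sketch below repeats that arithmetic in plain Python for a few hypothetical dataset sizes; ``split_sizes`` is an illustrative name and not part of the training code. Because ``int`` truncates, a dataset with fewer than 10 images gets an empty test set at 10%, so it is worth checking ``len(test_dataset)`` before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_sizes(num_items, test_fraction=0.1):\n",
    "    \"\"\"Hypothetical helper mirroring the train/test size arithmetic above.\"\"\"\n",
    "    num_test_items = int(test_fraction * num_items)  # int() truncates toward zero\n",
    "    return num_items - num_test_items, num_test_items\n",
    "\n",
    "# Hypothetical dataset sizes, printed as (train size, test size).\n",
    "for example_size in (9, 10, 205):\n",
    "    print(example_size, split_sizes(example_size))"
   ]
  },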
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create data loaders to load data in batches\n",
    "\n",
    "We use the ``DataLoader`` class to load data in batches, shuffle it, and read it with multiple worker subprocesses. In this example we use a batch size of 16. The batch size you can use depends on the memory available on your GPU, and it can affect the accuracy of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neural Network Model\n",
    "\n",
    "We use the ResNet-18 model available in PyTorch TorchVision.\n",
    "\n",
    "In a process called transfer learning, we can repurpose a pre-trained model (trained on millions of images) for a new task that may have much less data available.\n",
    "\n",
    "\n",
    "More details on ResNet-18: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n",
    "\n",
    "More details on transfer learning: https://www.youtube.com/watch?v=yofjFQddwHE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet18(pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ResNet model has a fully connected (fc) final layer with 512 ``in_features``. We are training for regression on two values (x and y), so we replace it with a layer that has 2 ``out_features``.\n",
    "\n",
    "Finally, we transfer our model to the GPU for execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fc = torch.nn.Linear(512, 2)\n",
    "device = torch.device('cuda')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Regression\n",
    "\n",
    "We train for 70 epochs and save the model whenever the test loss improves on the best value seen so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 70\n",
    "BEST_MODEL_PATH = 'best_steering_model_xy.pth'\n",
    "best_loss = 1e9\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for images, labels in iter(train_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(images)\n",
    "        loss = F.mse_loss(outputs, labels)\n",
    "        # Accumulate a plain float so each batch's autograd graph can be freed\n",
    "        train_loss += float(loss)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    train_loss /= len(train_loader)\n",
    "    \n",
    "    model.eval()\n",
    "    test_loss = 0.0\n",
    "    # Gradients are not needed for evaluation\n",
    "    with torch.no_grad():\n",
    "        for images, labels in iter(test_loader):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            outputs = model(images)\n",
    "            loss = F.mse_loss(outputs, labels)\n",
    "            test_loss += float(loss)\n",
    "    test_loss /= len(test_loader)\n",
    "    \n",
    "    print('%f, %f' % (train_loss, test_loss))\n",
    "    if test_loss < best_loss:\n",
    "        torch.save(model.state_dict(), BEST_MODEL_PATH)\n",
    "        best_loss = test_loss"
   ]
  },
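  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two numbers printed each epoch are mean squared errors. With its default ``reduction='mean'``, ``F.mse_loss`` averages the squared differences over every element, that is, over both the x and y outputs of every image in the batch. The plain-Python sketch below reproduces that arithmetic on made-up numbers; ``mse`` is an illustrative helper, not a PyTorch function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse(predictions, targets):\n",
    "    \"\"\"Mean of squared differences over every x and y value in the batch.\"\"\"\n",
    "    squared_errors = [\n",
    "        (p - t) ** 2\n",
    "        for prediction_row, target_row in zip(predictions, targets)\n",
    "        for p, t in zip(prediction_row, target_row)\n",
    "    ]\n",
    "    return sum(squared_errors) / len(squared_errors)\n",
    "\n",
    "# Made-up batch of two predictions and two labels, each an (x, y) pair.\n",
    "example_outputs = [[0.5, 0.0], [-0.5, 1.0]]\n",
    "example_labels = [[0.0, 0.0], [0.5, 1.0]]\n",
    "print(mse(example_outputs, example_labels))  # 0.3125"
   ]
  },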
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the model is trained, it will generate the ``best_steering_model_xy.pth`` file, which you can use for inference in the live demo notebook.\n",
    "\n",
    "If you trained on a machine other than the JetBot, you'll need to upload this file to the ``road_following`` example folder on the JetBot."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
